{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### This notebook demonstrates the use of the adversarial debiasing algorithm to learn a fair classifier.\n",
    "Adversarial debiasing [1] is an in-processing technique that learns a classifier to maximize prediction accuracy and simultaneously reduce an adversary's ability to determine the protected attribute from the predictions. This approach leads to a fair classifier because the predictions cannot carry any group discrimination information that the adversary can exploit. We will see how to use this algorithm for learning models with and without fairness constraints and apply them to the Adult dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/anaconda2/lib/python2.7/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "# Load all necessary packages\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from aif360.datasets import BinaryLabelDataset\n",
    "from aif360.datasets import AdultDataset, GermanDataset, CompasDataset\n",
    "from aif360.metrics import BinaryLabelDatasetMetric\n",
    "from aif360.metrics import ClassificationMetric\n",
    "from aif360.metrics.utils import compute_boolean_conditioning_vector\n",
    "\n",
    "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult, load_preproc_data_compas, load_preproc_data_german\n",
    "\n",
    "from aif360.algorithms.inprocessing.adversarial_debiasing import AdversarialDebiasing\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler, MaxAbsScaler\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "from IPython.display import Markdown, display\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow.compat.v1 as tf\n",
    "tf.disable_eager_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load dataset and set options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the dataset and split into train and test\n",
    "dataset_orig = load_preproc_data_adult()\n",
    "\n",
    "privileged_groups = [{'sex': 1}]\n",
    "unprivileged_groups = [{'sex': 0}]\n",
    "\n",
    "dataset_orig_train, dataset_orig_test = dataset_orig.split([0.7], shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "#### Training Dataset shape"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(34189, 18)\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Favorable and unfavorable labels"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1.0, 0.0)\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Protected attribute names"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['sex', 'race']\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Privileged and unprivileged protected attribute values"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([array([1.]), array([1.])], [array([0.]), array([0.])])\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Dataset feature names"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['race', 'sex', 'Age (decade)=10', 'Age (decade)=20', 'Age (decade)=30', 'Age (decade)=40', 'Age (decade)=50', 'Age (decade)=60', 'Age (decade)=>=70', 'Education Years=6', 'Education Years=7', 'Education Years=8', 'Education Years=9', 'Education Years=10', 'Education Years=11', 'Education Years=12', 'Education Years=<6', 'Education Years=>12']\n"
     ]
    }
   ],
   "source": [
    "# print out some labels, names, etc.\n",
    "display(Markdown(\"#### Training Dataset shape\"))\n",
    "print(dataset_orig_train.features.shape)\n",
    "display(Markdown(\"#### Favorable and unfavorable labels\"))\n",
    "print(dataset_orig_train.favorable_label, dataset_orig_train.unfavorable_label)\n",
    "display(Markdown(\"#### Protected attribute names\"))\n",
    "print(dataset_orig_train.protected_attribute_names)\n",
    "display(Markdown(\"#### Privileged and unprivileged protected attribute values\"))\n",
    "print(dataset_orig_train.privileged_protected_attributes, \n",
    "      dataset_orig_train.unprivileged_protected_attributes)\n",
    "display(Markdown(\"#### Dataset feature names\"))\n",
    "print(dataset_orig_train.feature_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Metrics for the original training and test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "#### Original training dataset"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.192750\n",
      "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.198626\n"
     ]
    }
   ],
   "source": [
    "# Metric for the original dataset\n",
    "metric_orig_train = BinaryLabelDatasetMetric(dataset_orig_train, \n",
    "                                             unprivileged_groups=unprivileged_groups,\n",
    "                                             privileged_groups=privileged_groups)\n",
    "display(Markdown(\"#### Original training dataset\"))\n",
    "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_orig_train.mean_difference())\n",
    "metric_orig_test = BinaryLabelDatasetMetric(dataset_orig_test, \n",
    "                                             unprivileged_groups=unprivileged_groups,\n",
    "                                             privileged_groups=privileged_groups)\n",
    "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_orig_test.mean_difference())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "#### Scaled dataset - Verify that the scaling does not affect the group label statistics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.192750\n",
      "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.198626\n"
     ]
    }
   ],
   "source": [
    "min_max_scaler = MaxAbsScaler()\n",
    "dataset_orig_train.features = min_max_scaler.fit_transform(dataset_orig_train.features)\n",
    "dataset_orig_test.features = min_max_scaler.transform(dataset_orig_test.features)\n",
    "metric_scaled_train = BinaryLabelDatasetMetric(dataset_orig_train, \n",
    "                             unprivileged_groups=unprivileged_groups,\n",
    "                             privileged_groups=privileged_groups)\n",
    "display(Markdown(\"#### Scaled dataset - Verify that the scaling does not affect the group label statistics\"))\n",
    "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_scaled_train.mean_difference())\n",
    "metric_scaled_test = BinaryLabelDatasetMetric(dataset_orig_test, \n",
    "                             unprivileged_groups=unprivileged_groups,\n",
    "                             privileged_groups=privileged_groups)\n",
    "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_scaled_test.mean_difference())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learn a plain classifier without debiasing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the in-processing AdversarialDebiasing model; with debias=False it trains a plain classifier\n",
    "# Learn parameters with debias set to False\n",
    "sess = tf.Session()\n",
    "plain_model = AdversarialDebiasing(privileged_groups = privileged_groups,\n",
    "                          unprivileged_groups = unprivileged_groups,\n",
    "                          scope_name='plain_classifier',\n",
    "                          debias=False,\n",
    "                          sess=sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0; iter: 0; batch classifier loss: 0.707587\n",
      "epoch 0; iter: 200; batch classifier loss: 0.396519\n",
      "epoch 1; iter: 0; batch classifier loss: 0.450665\n",
      "epoch 1; iter: 200; batch classifier loss: 0.439213\n",
      "epoch 2; iter: 0; batch classifier loss: 0.495045\n",
      "epoch 2; iter: 200; batch classifier loss: 0.513713\n",
      "epoch 3; iter: 0; batch classifier loss: 0.349774\n",
      "epoch 3; iter: 200; batch classifier loss: 0.380733\n",
      "epoch 4; iter: 0; batch classifier loss: 0.345100\n",
      "epoch 4; iter: 200; batch classifier loss: 0.399097\n",
      "epoch 5; iter: 0; batch classifier loss: 0.423275\n",
      "epoch 5; iter: 200; batch classifier loss: 0.418846\n",
      "epoch 6; iter: 0; batch classifier loss: 0.411661\n",
      "epoch 6; iter: 200; batch classifier loss: 0.357504\n",
      "epoch 7; iter: 0; batch classifier loss: 0.404039\n",
      "epoch 7; iter: 200; batch classifier loss: 0.447010\n",
      "epoch 8; iter: 0; batch classifier loss: 0.417079\n",
      "epoch 8; iter: 200; batch classifier loss: 0.513713\n",
      "epoch 9; iter: 0; batch classifier loss: 0.299503\n",
      "epoch 9; iter: 200; batch classifier loss: 0.447425\n",
      "epoch 10; iter: 0; batch classifier loss: 0.513632\n",
      "epoch 10; iter: 200; batch classifier loss: 0.376522\n",
      "epoch 11; iter: 0; batch classifier loss: 0.539716\n",
      "epoch 11; iter: 200; batch classifier loss: 0.449963\n",
      "epoch 12; iter: 0; batch classifier loss: 0.429474\n",
      "epoch 12; iter: 200; batch classifier loss: 0.408618\n",
      "epoch 13; iter: 0; batch classifier loss: 0.440729\n",
      "epoch 13; iter: 200; batch classifier loss: 0.464946\n",
      "epoch 14; iter: 0; batch classifier loss: 0.390707\n",
      "epoch 14; iter: 200; batch classifier loss: 0.482682\n",
      "epoch 15; iter: 0; batch classifier loss: 0.352653\n",
      "epoch 15; iter: 200; batch classifier loss: 0.423660\n",
      "epoch 16; iter: 0; batch classifier loss: 0.424234\n",
      "epoch 16; iter: 200; batch classifier loss: 0.390729\n",
      "epoch 17; iter: 0; batch classifier loss: 0.411589\n",
      "epoch 17; iter: 200; batch classifier loss: 0.389220\n",
      "epoch 18; iter: 0; batch classifier loss: 0.331668\n",
      "epoch 18; iter: 200; batch classifier loss: 0.384711\n",
      "epoch 19; iter: 0; batch classifier loss: 0.353290\n",
      "epoch 19; iter: 200; batch classifier loss: 0.457664\n",
      "epoch 20; iter: 0; batch classifier loss: 0.356439\n",
      "epoch 20; iter: 200; batch classifier loss: 0.334217\n",
      "epoch 21; iter: 0; batch classifier loss: 0.438827\n",
      "epoch 21; iter: 200; batch classifier loss: 0.382024\n",
      "epoch 22; iter: 0; batch classifier loss: 0.420756\n",
      "epoch 22; iter: 200; batch classifier loss: 0.374907\n",
      "epoch 23; iter: 0; batch classifier loss: 0.475280\n",
      "epoch 23; iter: 200; batch classifier loss: 0.426664\n",
      "epoch 24; iter: 0; batch classifier loss: 0.351704\n",
      "epoch 24; iter: 200; batch classifier loss: 0.361529\n",
      "epoch 25; iter: 0; batch classifier loss: 0.411303\n",
      "epoch 25; iter: 200; batch classifier loss: 0.487325\n",
      "epoch 26; iter: 0; batch classifier loss: 0.407306\n",
      "epoch 26; iter: 200; batch classifier loss: 0.484252\n",
      "epoch 27; iter: 0; batch classifier loss: 0.364663\n",
      "epoch 27; iter: 200; batch classifier loss: 0.455063\n",
      "epoch 28; iter: 0; batch classifier loss: 0.434696\n",
      "epoch 28; iter: 200; batch classifier loss: 0.449683\n",
      "epoch 29; iter: 0; batch classifier loss: 0.418321\n",
      "epoch 29; iter: 200; batch classifier loss: 0.434468\n",
      "epoch 30; iter: 0; batch classifier loss: 0.409858\n",
      "epoch 30; iter: 200; batch classifier loss: 0.466626\n",
      "epoch 31; iter: 0; batch classifier loss: 0.450511\n",
      "epoch 31; iter: 200; batch classifier loss: 0.450152\n",
      "epoch 32; iter: 0; batch classifier loss: 0.465642\n",
      "epoch 32; iter: 200; batch classifier loss: 0.428328\n",
      "epoch 33; iter: 0; batch classifier loss: 0.392987\n",
      "epoch 33; iter: 200; batch classifier loss: 0.373837\n",
      "epoch 34; iter: 0; batch classifier loss: 0.448555\n",
      "epoch 34; iter: 200; batch classifier loss: 0.485128\n",
      "epoch 35; iter: 0; batch classifier loss: 0.344462\n",
      "epoch 35; iter: 200; batch classifier loss: 0.388613\n",
      "epoch 36; iter: 0; batch classifier loss: 0.466822\n",
      "epoch 36; iter: 200; batch classifier loss: 0.363230\n",
      "epoch 37; iter: 0; batch classifier loss: 0.440089\n",
      "epoch 37; iter: 200; batch classifier loss: 0.382196\n",
      "epoch 38; iter: 0; batch classifier loss: 0.386720\n",
      "epoch 38; iter: 200; batch classifier loss: 0.447435\n",
      "epoch 39; iter: 0; batch classifier loss: 0.384074\n",
      "epoch 39; iter: 200; batch classifier loss: 0.394575\n",
      "epoch 40; iter: 0; batch classifier loss: 0.378215\n",
      "epoch 40; iter: 200; batch classifier loss: 0.421163\n",
      "epoch 41; iter: 0; batch classifier loss: 0.387049\n",
      "epoch 41; iter: 200; batch classifier loss: 0.392461\n",
      "epoch 42; iter: 0; batch classifier loss: 0.392354\n",
      "epoch 42; iter: 200; batch classifier loss: 0.413999\n",
      "epoch 43; iter: 0; batch classifier loss: 0.447966\n",
      "epoch 43; iter: 200; batch classifier loss: 0.417566\n",
      "epoch 44; iter: 0; batch classifier loss: 0.507449\n",
      "epoch 44; iter: 200; batch classifier loss: 0.407887\n",
      "epoch 45; iter: 0; batch classifier loss: 0.396286\n",
      "epoch 45; iter: 200; batch classifier loss: 0.390399\n",
      "epoch 46; iter: 0; batch classifier loss: 0.418439\n",
      "epoch 46; iter: 200; batch classifier loss: 0.380013\n",
      "epoch 47; iter: 0; batch classifier loss: 0.407893\n",
      "epoch 47; iter: 200; batch classifier loss: 0.433631\n",
      "epoch 48; iter: 0; batch classifier loss: 0.461974\n",
      "epoch 48; iter: 200; batch classifier loss: 0.447301\n",
      "epoch 49; iter: 0; batch classifier loss: 0.356089\n",
      "epoch 49; iter: 200; batch classifier loss: 0.467275\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<aif360.algorithms.inprocessing.adversarial_debiasing.AdversarialDebiasing at 0x106cc2f10>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "plain_model.fit(dataset_orig_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the plain model to the train and test data\n",
    "dataset_nodebiasing_train = plain_model.predict(dataset_orig_train)\n",
    "dataset_nodebiasing_test = plain_model.predict(dataset_orig_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "#### Plain model - without debiasing - dataset metrics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.217876\n",
      "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.221187\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Plain model - without debiasing - classification metrics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test set: Classification accuracy = 0.804955\n",
      "Test set: Balanced classification accuracy = 0.666400\n",
      "Test set: Disparate impact = 0.000000\n",
      "Test set: Equal opportunity difference = -0.470687\n",
      "Test set: Average odds difference = -0.291055\n",
      "Test set: Theil_index = 0.175113\n"
     ]
    }
   ],
   "source": [
    "# Metrics for the dataset from plain model (without debiasing)\n",
    "display(Markdown(\"#### Plain model - without debiasing - dataset metrics\"))\n",
    "metric_dataset_nodebiasing_train = BinaryLabelDatasetMetric(dataset_nodebiasing_train, \n",
    "                                             unprivileged_groups=unprivileged_groups,\n",
    "                                             privileged_groups=privileged_groups)\n",
    "\n",
    "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_dataset_nodebiasing_train.mean_difference())\n",
    "\n",
    "metric_dataset_nodebiasing_test = BinaryLabelDatasetMetric(dataset_nodebiasing_test, \n",
    "                                             unprivileged_groups=unprivileged_groups,\n",
    "                                             privileged_groups=privileged_groups)\n",
    "\n",
    "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_dataset_nodebiasing_test.mean_difference())\n",
    "\n",
    "display(Markdown(\"#### Plain model - without debiasing - classification metrics\"))\n",
    "classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, \n",
    "                                                 dataset_nodebiasing_test,\n",
    "                                                 unprivileged_groups=unprivileged_groups,\n",
    "                                                 privileged_groups=privileged_groups)\n",
    "print(\"Test set: Classification accuracy = %f\" % classified_metric_nodebiasing_test.accuracy())\n",
    "TPR = classified_metric_nodebiasing_test.true_positive_rate()\n",
    "TNR = classified_metric_nodebiasing_test.true_negative_rate()\n",
    "bal_acc_nodebiasing_test = 0.5*(TPR+TNR)\n",
    "print(\"Test set: Balanced classification accuracy = %f\" % bal_acc_nodebiasing_test)\n",
    "print(\"Test set: Disparate impact = %f\" % classified_metric_nodebiasing_test.disparate_impact())\n",
    "print(\"Test set: Equal opportunity difference = %f\" % classified_metric_nodebiasing_test.equal_opportunity_difference())\n",
    "print(\"Test set: Average odds difference = %f\" % classified_metric_nodebiasing_test.average_odds_difference())\n",
    "print(\"Test set: Theil_index = %f\" % classified_metric_nodebiasing_test.theil_index())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Apply the in-processing algorithm based on adversarial learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.close()\n",
    "tf.reset_default_graph()\n",
    "sess = tf.Session()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learn parameters with debias set to True\n",
    "debiased_model = AdversarialDebiasing(privileged_groups = privileged_groups,\n",
    "                          unprivileged_groups = unprivileged_groups,\n",
    "                          scope_name='debiased_classifier',\n",
    "                          debias=True,\n",
    "                          sess=sess)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0; iter: 0; batch classifier loss: 0.721611; batch adversarial loss: 0.630777\n",
      "epoch 0; iter: 200; batch classifier loss: 0.442980; batch adversarial loss: 0.656542\n",
      "epoch 1; iter: 0; batch classifier loss: 0.453149; batch adversarial loss: 0.657557\n",
      "epoch 1; iter: 200; batch classifier loss: 0.496931; batch adversarial loss: 0.617686\n",
      "epoch 2; iter: 0; batch classifier loss: 0.547117; batch adversarial loss: 0.653103\n",
      "epoch 2; iter: 200; batch classifier loss: 0.331452; batch adversarial loss: 0.617297\n",
      "epoch 3; iter: 0; batch classifier loss: 0.407935; batch adversarial loss: 0.627860\n",
      "epoch 3; iter: 200; batch classifier loss: 0.413469; batch adversarial loss: 0.616086\n",
      "epoch 4; iter: 0; batch classifier loss: 0.370982; batch adversarial loss: 0.604738\n",
      "epoch 4; iter: 200; batch classifier loss: 0.469453; batch adversarial loss: 0.617892\n",
      "epoch 5; iter: 0; batch classifier loss: 0.502638; batch adversarial loss: 0.595247\n",
      "epoch 5; iter: 200; batch classifier loss: 0.379807; batch adversarial loss: 0.635309\n",
      "epoch 6; iter: 0; batch classifier loss: 0.484228; batch adversarial loss: 0.591971\n",
      "epoch 6; iter: 200; batch classifier loss: 0.421526; batch adversarial loss: 0.612220\n",
      "epoch 7; iter: 0; batch classifier loss: 0.392731; batch adversarial loss: 0.636230\n",
      "epoch 7; iter: 200; batch classifier loss: 0.391191; batch adversarial loss: 0.614543\n",
      "epoch 8; iter: 0; batch classifier loss: 0.481106; batch adversarial loss: 0.665277\n",
      "epoch 8; iter: 200; batch classifier loss: 0.465566; batch adversarial loss: 0.638007\n",
      "epoch 9; iter: 0; batch classifier loss: 0.418696; batch adversarial loss: 0.575338\n",
      "epoch 9; iter: 200; batch classifier loss: 0.450467; batch adversarial loss: 0.611760\n",
      "epoch 10; iter: 0; batch classifier loss: 0.347429; batch adversarial loss: 0.652913\n",
      "epoch 10; iter: 200; batch classifier loss: 0.398043; batch adversarial loss: 0.601349\n",
      "epoch 11; iter: 0; batch classifier loss: 0.478137; batch adversarial loss: 0.630736\n",
      "epoch 11; iter: 200; batch classifier loss: 0.479216; batch adversarial loss: 0.552848\n",
      "epoch 12; iter: 0; batch classifier loss: 0.473339; batch adversarial loss: 0.597487\n",
      "epoch 12; iter: 200; batch classifier loss: 0.378719; batch adversarial loss: 0.630144\n",
      "epoch 13; iter: 0; batch classifier loss: 0.461751; batch adversarial loss: 0.583154\n",
      "epoch 13; iter: 200; batch classifier loss: 0.427811; batch adversarial loss: 0.594790\n",
      "epoch 14; iter: 0; batch classifier loss: 0.520254; batch adversarial loss: 0.586920\n",
      "epoch 14; iter: 200; batch classifier loss: 0.375389; batch adversarial loss: 0.622141\n",
      "epoch 15; iter: 0; batch classifier loss: 0.358494; batch adversarial loss: 0.610482\n",
      "epoch 15; iter: 200; batch classifier loss: 0.377246; batch adversarial loss: 0.600464\n",
      "epoch 16; iter: 0; batch classifier loss: 0.330568; batch adversarial loss: 0.631124\n",
      "epoch 16; iter: 200; batch classifier loss: 0.493238; batch adversarial loss: 0.602217\n",
      "epoch 17; iter: 0; batch classifier loss: 0.430809; batch adversarial loss: 0.622507\n",
      "epoch 17; iter: 200; batch classifier loss: 0.420727; batch adversarial loss: 0.631383\n",
      "epoch 18; iter: 0; batch classifier loss: 0.463418; batch adversarial loss: 0.613122\n",
      "epoch 18; iter: 200; batch classifier loss: 0.407586; batch adversarial loss: 0.583201\n",
      "epoch 19; iter: 0; batch classifier loss: 0.438854; batch adversarial loss: 0.588028\n",
      "epoch 19; iter: 200; batch classifier loss: 0.468554; batch adversarial loss: 0.586143\n",
      "epoch 20; iter: 0; batch classifier loss: 0.491485; batch adversarial loss: 0.627042\n",
      "epoch 20; iter: 200; batch classifier loss: 0.434700; batch adversarial loss: 0.629269\n",
      "epoch 21; iter: 0; batch classifier loss: 0.445875; batch adversarial loss: 0.589738\n",
      "epoch 21; iter: 200; batch classifier loss: 0.435593; batch adversarial loss: 0.629081\n",
      "epoch 22; iter: 0; batch classifier loss: 0.364423; batch adversarial loss: 0.610640\n",
      "epoch 22; iter: 200; batch classifier loss: 0.389425; batch adversarial loss: 0.605668\n",
      "epoch 23; iter: 0; batch classifier loss: 0.562680; batch adversarial loss: 0.634945\n",
      "epoch 23; iter: 200; batch classifier loss: 0.473808; batch adversarial loss: 0.566636\n",
      "epoch 24; iter: 0; batch classifier loss: 0.424366; batch adversarial loss: 0.585584\n",
      "epoch 24; iter: 200; batch classifier loss: 0.359588; batch adversarial loss: 0.609465\n",
      "epoch 25; iter: 0; batch classifier loss: 0.519477; batch adversarial loss: 0.564588\n",
      "epoch 25; iter: 200; batch classifier loss: 0.449761; batch adversarial loss: 0.571238\n",
      "epoch 26; iter: 0; batch classifier loss: 0.447675; batch adversarial loss: 0.591839\n",
      "epoch 26; iter: 200; batch classifier loss: 0.369251; batch adversarial loss: 0.580864\n",
      "epoch 27; iter: 0; batch classifier loss: 0.384472; batch adversarial loss: 0.661156\n",
      "epoch 27; iter: 200; batch classifier loss: 0.393334; batch adversarial loss: 0.638825\n",
      "epoch 28; iter: 0; batch classifier loss: 0.451982; batch adversarial loss: 0.552013\n",
      "epoch 28; iter: 200; batch classifier loss: 0.399544; batch adversarial loss: 0.612651\n",
      "epoch 29; iter: 0; batch classifier loss: 0.390971; batch adversarial loss: 0.580380\n",
      "epoch 29; iter: 200; batch classifier loss: 0.401580; batch adversarial loss: 0.582367\n",
      "epoch 30; iter: 0; batch classifier loss: 0.297665; batch adversarial loss: 0.547717\n",
      "epoch 30; iter: 200; batch classifier loss: 0.470934; batch adversarial loss: 0.625385\n",
      "epoch 31; iter: 0; batch classifier loss: 0.418402; batch adversarial loss: 0.622812\n",
      "epoch 31; iter: 200; batch classifier loss: 0.385281; batch adversarial loss: 0.603873\n",
      "epoch 32; iter: 0; batch classifier loss: 0.418848; batch adversarial loss: 0.573049\n",
      "epoch 32; iter: 200; batch classifier loss: 0.443066; batch adversarial loss: 0.621068\n",
      "epoch 33; iter: 0; batch classifier loss: 0.461614; batch adversarial loss: 0.606992\n",
      "epoch 33; iter: 200; batch classifier loss: 0.451093; batch adversarial loss: 0.621659\n",
      "epoch 34; iter: 0; batch classifier loss: 0.407544; batch adversarial loss: 0.646782\n",
      "epoch 34; iter: 200; batch classifier loss: 0.441481; batch adversarial loss: 0.645866\n",
      "epoch 35; iter: 0; batch classifier loss: 0.344949; batch adversarial loss: 0.589151\n",
      "epoch 35; iter: 200; batch classifier loss: 0.387160; batch adversarial loss: 0.549727\n",
      "epoch 36; iter: 0; batch classifier loss: 0.432171; batch adversarial loss: 0.675994\n",
      "epoch 36; iter: 200; batch classifier loss: 0.388955; batch adversarial loss: 0.621595\n",
      "epoch 37; iter: 0; batch classifier loss: 0.443978; batch adversarial loss: 0.658480\n",
      "epoch 37; iter: 200; batch classifier loss: 0.422210; batch adversarial loss: 0.617039\n",
      "epoch 38; iter: 0; batch classifier loss: 0.381281; batch adversarial loss: 0.588504\n",
      "epoch 38; iter: 200; batch classifier loss: 0.323892; batch adversarial loss: 0.596638\n",
      "epoch 39; iter: 0; batch classifier loss: 0.396359; batch adversarial loss: 0.614882\n",
      "epoch 39; iter: 200; batch classifier loss: 0.473418; batch adversarial loss: 0.562516\n",
      "epoch 40; iter: 0; batch classifier loss: 0.415690; batch adversarial loss: 0.617672\n",
      "epoch 40; iter: 200; batch classifier loss: 0.472975; batch adversarial loss: 0.537192\n",
      "epoch 41; iter: 0; batch classifier loss: 0.473487; batch adversarial loss: 0.591801\n",
      "epoch 41; iter: 200; batch classifier loss: 0.379132; batch adversarial loss: 0.602665\n",
      "epoch 42; iter: 0; batch classifier loss: 0.418546; batch adversarial loss: 0.568511\n",
      "epoch 42; iter: 200; batch classifier loss: 0.366345; batch adversarial loss: 0.603213\n",
      "epoch 43; iter: 0; batch classifier loss: 0.364993; batch adversarial loss: 0.596730\n",
      "epoch 43; iter: 200; batch classifier loss: 0.436417; batch adversarial loss: 0.611999\n",
      "epoch 44; iter: 0; batch classifier loss: 0.419406; batch adversarial loss: 0.602352\n",
      "epoch 44; iter: 200; batch classifier loss: 0.472369; batch adversarial loss: 0.592246\n",
      "epoch 45; iter: 0; batch classifier loss: 0.479547; batch adversarial loss: 0.564802\n",
      "epoch 45; iter: 200; batch classifier loss: 0.476123; batch adversarial loss: 0.603599\n",
      "epoch 46; iter: 0; batch classifier loss: 0.546357; batch adversarial loss: 0.631894\n",
      "epoch 46; iter: 200; batch classifier loss: 0.389170; batch adversarial loss: 0.576345\n",
      "epoch 47; iter: 0; batch classifier loss: 0.480703; batch adversarial loss: 0.603182\n",
      "epoch 47; iter: 200; batch classifier loss: 0.586694; batch adversarial loss: 0.635715\n",
      "epoch 48; iter: 0; batch classifier loss: 0.394101; batch adversarial loss: 0.558852\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 48; iter: 200; batch classifier loss: 0.453874; batch adversarial loss: 0.602889\n",
      "epoch 49; iter: 0; batch classifier loss: 0.506737; batch adversarial loss: 0.624289\n",
      "epoch 49; iter: 200; batch classifier loss: 0.359482; batch adversarial loss: 0.618086\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<aif360.algorithms.inprocessing.adversarial_debiasing.AdversarialDebiasing at 0x1c32efcf10>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "debiased_model.fit(dataset_orig_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the debiased model to train and test data\n",
    "dataset_debiasing_train = debiased_model.predict(dataset_orig_train)\n",
    "dataset_debiasing_test = debiased_model.predict(dataset_orig_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "#### Plain model - without debiasing - dataset metrics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.217876\n",
      "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.221187\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Model - with debiasing - dataset metrics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.090157\n",
      "Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.094732\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Plain model - without debiasing - classification metrics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test set: Classification accuracy = 0.804955\n",
      "Test set: Balanced classification accuracy = 0.666400\n",
      "Test set: Disparate impact = 0.000000\n",
      "Test set: Equal opportunity difference = -0.470687\n",
      "Test set: Average odds difference = -0.291055\n",
      "Test set: Theil_index = 0.175113\n"
     ]
    },
    {
     "data": {
      "text/markdown": [
       "#### Model - with debiasing - classification metrics"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test set: Classification accuracy = 0.792056\n",
      "Test set: Balanced classification accuracy = 0.672481\n",
      "Test set: Disparate impact = 0.553746\n",
      "Test set: Equal opportunity difference = -0.090716\n",
      "Test set: Average odds difference = -0.053841\n",
      "Test set: Theil_index = 0.170358\n"
     ]
    }
   ],
   "source": [
    "# Metrics for the dataset from plain model (without debiasing)\n",
    "display(Markdown(\"#### Plain model - without debiasing - dataset metrics\"))\n",
    "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_dataset_nodebiasing_train.mean_difference())\n",
    "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_dataset_nodebiasing_test.mean_difference())\n",
    "\n",
    "# Metrics for the dataset from model with debiasing\n",
    "display(Markdown(\"#### Model - with debiasing - dataset metrics\"))\n",
    "metric_dataset_debiasing_train = BinaryLabelDatasetMetric(dataset_debiasing_train, \n",
    "                                             unprivileged_groups=unprivileged_groups,\n",
    "                                             privileged_groups=privileged_groups)\n",
    "\n",
    "print(\"Train set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_dataset_debiasing_train.mean_difference())\n",
    "\n",
    "metric_dataset_debiasing_test = BinaryLabelDatasetMetric(dataset_debiasing_test, \n",
    "                                             unprivileged_groups=unprivileged_groups,\n",
    "                                             privileged_groups=privileged_groups)\n",
    "\n",
    "print(\"Test set: Difference in mean outcomes between unprivileged and privileged groups = %f\" % metric_dataset_debiasing_test.mean_difference())\n",
    "\n",
    "\n",
    "\n",
    "display(Markdown(\"#### Plain model - without debiasing - classification metrics\"))\n",
    "print(\"Test set: Classification accuracy = %f\" % classified_metric_nodebiasing_test.accuracy())\n",
    "TPR = classified_metric_nodebiasing_test.true_positive_rate()\n",
    "TNR = classified_metric_nodebiasing_test.true_negative_rate()\n",
    "bal_acc_nodebiasing_test = 0.5*(TPR+TNR)\n",
    "print(\"Test set: Balanced classification accuracy = %f\" % bal_acc_nodebiasing_test)\n",
    "print(\"Test set: Disparate impact = %f\" % classified_metric_nodebiasing_test.disparate_impact())\n",
    "print(\"Test set: Equal opportunity difference = %f\" % classified_metric_nodebiasing_test.equal_opportunity_difference())\n",
    "print(\"Test set: Average odds difference = %f\" % classified_metric_nodebiasing_test.average_odds_difference())\n",
    "print(\"Test set: Theil_index = %f\" % classified_metric_nodebiasing_test.theil_index())\n",
    "\n",
    "\n",
    "\n",
    "display(Markdown(\"#### Model - with debiasing - classification metrics\"))\n",
    "classified_metric_debiasing_test = ClassificationMetric(dataset_orig_test, \n",
    "                                                 dataset_debiasing_test,\n",
    "                                                 unprivileged_groups=unprivileged_groups,\n",
    "                                                 privileged_groups=privileged_groups)\n",
    "print(\"Test set: Classification accuracy = %f\" % classified_metric_debiasing_test.accuracy())\n",
    "TPR = classified_metric_debiasing_test.true_positive_rate()\n",
    "TNR = classified_metric_debiasing_test.true_negative_rate()\n",
    "bal_acc_debiasing_test = 0.5*(TPR+TNR)\n",
    "print(\"Test set: Balanced classification accuracy = %f\" % bal_acc_debiasing_test)\n",
    "print(\"Test set: Disparate impact = %f\" % classified_metric_debiasing_test.disparate_impact())\n",
    "print(\"Test set: Equal opportunity difference = %f\" % classified_metric_debiasing_test.equal_opportunity_difference())\n",
    "print(\"Test set: Average odds difference = %f\" % classified_metric_debiasing_test.average_odds_difference())\n",
    "print(\"Test set: Theil_index = %f\" % classified_metric_debiasing_test.theil_index())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "    References:\n",
    "    [1] B. H. Zhang, B. Lemoine, and M. Mitchell, \"Mitigating Unwanted Biases with Adversarial Learning,\" \n",
    "    AAAI/ACM Conference on Artificial Intelligence, Ethics, and Society, 2018."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
